
[PD] prepare request in prefill instance by multi threads#7724

Merged
Jiang-Jia-Jun merged 1 commit into
PaddlePaddle:release/2.6from
juncaipeng:refine-pd-fetch-req
May 13, 2026
Conversation

@juncaipeng
Collaborator

@juncaipeng juncaipeng commented May 6, 2026

Motivation

  • Prepare requests in the prefill instance concurrently with multiple threads, reducing serial waiting in the request-preparation stage and improving throughput. The original implementation ran preallocate_resource_in_p → send_splitwise_tasks → check_decode_allocated sequentially on a single thread, so high response latency on the D side blocked preparation of subsequent requests. The new implementation moves request preparation into a dedicated prepare thread pool (5 threads by default) that runs in parallel with the schedule thread, and removes the old finished_add_cache_task synchronization in favor of active polling on the cache_messager side.
  • With multiple P and multiple D instances, prevent a single resource-starved D instance from dragging down the throughput of a P instance.
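The prepare-pool idea above can be sketched minimally as follows. This is an illustration, not the actual FastDeploy code: `prepare_one`, `prepare_loop`, and the queue are hypothetical stand-ins for the real request-preparation steps.

```python
import queue
from concurrent.futures import ThreadPoolExecutor

PREPARE_THREAD_NUM = 5  # mirrors the default of FD_PREFILL_PREPARE_REQ_THREAD_NUM

def prepare_one(req):
    # Stand-in for preallocate_resource_in_p -> send_splitwise_tasks
    # -> check_decode_allocated; each step may block on the D side.
    return f"prepared-{req}"

def prepare_loop(req_queue, ready):
    """Worker body: each of the N prepare threads runs this loop."""
    while True:
        req = req_queue.get()
        if req is None:          # sentinel: shut this worker down
            return
        ready.append(prepare_one(req))

req_queue, ready = queue.Queue(), []
with ThreadPoolExecutor(max_workers=PREPARE_THREAD_NUM) as pool:
    for _ in range(PREPARE_THREAD_NUM):
        pool.submit(prepare_loop, req_queue, ready)
    for i in range(8):
        req_queue.put(i)         # the schedule thread keeps running meanwhile
    for _ in range(PREPARE_THREAD_NUM):
        req_queue.put(None)      # one sentinel per worker
```

Because the workers pull from a shared queue, one slow request (e.g. a D instance that answers late) delays only the thread holding it, not the whole preparation pipeline.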

Modifications

  • fastdeploy/engine/common_engine_prepare_mixin.py (new): extracts EngineServicePrepareMixin, implementing request preparation for the three roles _fetch_request_mixed, _fetch_request_prefill, and _fetch_request_decode, plus the _fetch_loop worker thread and the _prepare_request_v1 entry point.
  • fastdeploy/engine/common_engine.py: EngineService inherits EngineServicePrepareMixin; under ENABLE_V1_KVCACHE_SCHEDULER, the new prepare_request_thread and schedule_request_thread run in parallel; the former inline ThreadPoolExecutor and _fetch_request closure are removed.
  • fastdeploy/envs.py: adds FD_PREFILL_PREPARE_REQ_THREAD_NUM (default 5) to control the number of request-preparation threads in the prefill instance.
  • fastdeploy/inter_communicator/engine_worker_queue.py: removes the finished_add_cache_task queues, locks, barriers, and flags, along with the put_finished_add_cache_task_req / get_finished_add_cache_task_req methods.
  • fastdeploy/cache_manager/cache_messager.py: adds a _maybe_wait_for_cache_task polling method, called from prefill_layerwise_send_cache_thread to actively wait until the cache task is ready; removes the old finish_add_cache_task_barrier synchronization point.
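The polling wait that replaces the removed ack queue can be sketched like this. It is a simplified stand-in: `cache_tasks` and the function body are illustrative, not the actual cache_messager implementation.

```python
import threading
import time

def maybe_wait_for_cache_task(cache_tasks, engine_index, poll_interval=0.001):
    """Poll until a cache task for engine_index shows up, then return it.

    Mirrors the idea of _maybe_wait_for_cache_task: the sender actively
    waits instead of relying on a finished_add_cache_task ack queue.
    """
    while engine_index not in cache_tasks:
        time.sleep(poll_interval)
    return cache_tasks[engine_index]

# The engine side registers the task slightly later, on another thread.
cache_tasks = {}

def register_later():
    time.sleep(0.01)
    cache_tasks[7] = "blocks-ready"

threading.Thread(target=register_later).start()
result = maybe_wait_for_cache_task(cache_tasks, 7)
```

Dropping the ack round-trip is what allows several prepare threads to enqueue tasks concurrently without serializing on a shared barrier.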

Usage or Command

# Control the number of prefill request-preparation threads via an environment variable (default: 5)
export FD_PREFILL_PREPARE_REQ_THREAD_NUM=5

Tests

Benchmark setup: EB45 0.3B model, 1P1D, 1000 requests, concurrency 256, average input length 1280, average output length 700

  • Baseline: 275s
  • With concurrency 1: 263s
  • With concurrency 5: 260s

N/A (this change is a performance optimization and does not affect model computation results)

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings May 6, 2026 08:55
@paddle-bot

paddle-bot Bot commented May 6, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

In the PD disaggregation scenario, this PR splits the request fetch/prepare stage of the prefill instance out of the implementation previously coupled with the scheduling thread, and attempts to improve throughput through multi-threaded concurrent fetch/prepare. It also simplifies the "add cache task finished" synchronization between the engine and cache_messager and adds a related environment variable.

Changes:

  • Moves the "fetch request + prepare request" logic under ENABLE_V1_KVCACHE_SCHEDULER into EngineServicePrepareMixin, and adds two cooperating threads, prepare_request_thread and schedule_request_thread, in EngineService.start().
  • Removes the finished_add_cache_task_* queues/locks/manager registrations and the put/get_finished_add_cache_task_req interfaces from EngineWorkerQueue; the cache_messager side now waits on demand for the task to be ready before sending the cache.
  • Adds the environment variable FD_PREFILL_FETCH_THREAD_NUM to control the number of fetch worker threads in the prefill instance.

The PR title/description needs additions:

  • The title does not carry a tag per the repository convention (e.g. [Optimization]... / [PD Disaggregation]...); a suggestion: [PD Disaggregation][Optimization] Prepare requests in prefill with multi-thread fetching
  • The PR description template (Motivation / Modifications / Usage / Accuracy Tests) is not filled in; at minimum explain why multi-threading is needed, the expected benefit and risk, how it was verified (benchmark/trace/stress metrics), and why unit tests are missing or how they will be added.
  • Update the documentation (e.g. environment_variables.md and related usage notes) for the new environment variable and behavior change.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

File Description
fastdeploy/inter_communicator/engine_worker_queue.py Removes the "add cache task finished" shared queues/locks/registrations and their interfaces, simplifying the engine-to-worker synchronization surface.
fastdeploy/envs.py Adds reading of the FD_PREFILL_FETCH_THREAD_NUM environment variable.
fastdeploy/engine/common_engine.py EngineService adopts the prepare mixin; splits prepare/schedule into two threads; shuts down the fetch pool on exit.
fastdeploy/engine/common_engine_prepare_mixin.py New mixin: implements fetch/prepare for the mixed/prefill/decode roles, using a ThreadPoolExecutor for multi-threaded prefill fetch.
fastdeploy/cache_manager/cache_messager.py Removes the dependency on the finished_add_cache_task queue/barrier; adds logic to wait for cache tasks to arrive; adjusts some comments and logs.
Comments suppressed due to low confidence (1)

fastdeploy/inter_communicator/engine_worker_queue.py:100

  • This PR removes the finished_add_cache_task_* queues/locks/interfaces (e.g. put/get_finished_add_cache_task_req and the corresponding manager registrations). Several tests in the repository still reference these interfaces directly (e.g. tests/inter_communicator/test_e2w_queue.py, tests/engine/test_common_engine.py, tests/cache_manager/test_cache_messager.py), which will break unit tests and downstream callers. Update the tests and all call sites together; if the old behavior must stay compatible, keep the interfaces as empty/deprecated implementations that log a migration hint.
            ]
            self.connected_client_counter_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)]
            self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
            self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
            self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
            self.client_read_info_flag_init: List[List[int]] = [
                [0] * self.num_client for _ in range(self.local_data_parallel_size)
            ]
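The deprecated-stub approach suggested in the review comment could look roughly like this. It is hypothetical: the PR actually removes the methods outright, and the class body below elides everything except the compatibility stubs.

```python
import warnings

class EngineWorkerQueue:
    # Existing queue implementation elided; only compatibility stubs
    # for the removed methods are sketched here.

    def put_finished_add_cache_task_req(self, req):
        """Deprecated no-op: the ack queue no longer exists."""
        warnings.warn(
            "put_finished_add_cache_task_req is deprecated; cache_messager "
            "now polls for cache tasks directly.",
            DeprecationWarning,
            stacklevel=2,
        )

    def get_finished_add_cache_task_req(self):
        """Deprecated: always returns an empty list."""
        warnings.warn(
            "get_finished_add_cache_task_req is deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )
        return []
```

Stubs like these let out-of-tree callers keep running while the warning points them at the new polling mechanism.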

Comment thread fastdeploy/engine/common_engine_prepare_mixin.py Outdated
Comment thread fastdeploy/engine/common_engine_prepare_mixin.py
Comment thread fastdeploy/envs.py
Comment thread fastdeploy/cache_manager/cache_messager.py
Comment thread fastdeploy/cache_manager/cache_messager.py Outdated
PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot

PaddlePaddle-bot commented May 6, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-12 13:39:11

The CI report is generated from the following code (refreshed every 30 minutes):


1 Task overview

1 Required task failed (Approval) and needs priority handling; 3 Required tasks are still running, so CI has not yet fully completed.

Total runs (reruns) Total tasks ✅ Passed ❌ Failed ⏳ Running ⏸️ Waiting Skipped
36(0) 36 28 4 3 1 0

2 Task status summary

2.1 Required tasks: 6/10 passed

Required tasks block merging; failures must be handled first.

Status Task Duration Root cause Suggested fix Log Rerun
Approval 8s PR issues: envs.py modified without approval; Cherry-Pick title format violation Request a designated RD's approval; fix the title per the Cherry-Pick convention Job -
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage - Running - Job -
Extracted partial CE model tasks to run in CI. / run_ce_cases - Running - Job -
xpu_4cards_case_test / run_xpu_4cards_cases - Running - Job -
The remaining 6 required tasks passed - - - - -

2.2 Optional tasks — 22/26 passed

Optional tasks do not block merging; failures are informational only.

Status Task Duration Log Rerun
Run iluvatar Tests / run_iluvatar_cases 15m55s Job -
Check PR Template 15s Job -
Trigger Jenkins for PR 50s Job -
⏸️ CI_HPU - - -
The remaining 22 optional tasks passed - - -

3 Failure details (required only)

Approval — process approval (confidence: high)

Approval

  • Status: ❌ failed
  • Error type: process approval
  • Confidence: high
  • Root-cause summary: the PR modifies envs.py without approval from a designated RD, and the Cherry-Pick PR title does not follow the required format
  • Analyzer: generic analysis (fallback)

Root-cause details:
scripts/check_approval.sh detected 2 approval errors:

  1. The PR modifies fastdeploy/envs.py, which requires an Approve from at least one of jiangjiajun/liuyuanle/chenjian26/wanglongzhi.
  2. The PR targets release/2.6 (a release branch); by convention such PRs must be Cherry-Picks whose title contains [Cherry-Pick] and the original develop PR number, approved by qingqing01/jiangjiajun/heavengate. The current title, [PD] prepare request in prefill instance by multi threads, does not comply.

Key log:

0. You must have one FastDeploy RD (jiangjiajun/liuyuanle/chenjian26/wanglongzhi2001) approval for modifying [fastdeploy/envs.py].
1. Cherry-Pick PR must come from develop and the title must contain [Cherry-Pick] and the original develop PR number.
There are 2 approved errors.
Process completed with exit code 6.

Suggested fixes:

  1. Ask at least one of jiangjiajun/liuyuanle/chenjian26/wanglongzhi to approve this PR (covering the approval requirement for the fastdeploy/envs.py change).
  2. If this PR is a Cherry-Pick, change the title to the format [Cherry-Pick][PD] prepare request in prefill instance by multi threads (#XXXX) (where XXXX is the original develop PR number), and request approval from qingqing01/jiangjiajun/heavengate.

Fix summary: request a designated RD's approval for the envs.py change, and fix the PR title per the Cherry-Pick convention.

Related changes: the PR modifies fastdeploy/envs.py (triggering the approval requirement); the target branch is release/2.6 (triggering the Cherry-Pick convention check).

Link: view log

PaddlePaddle-bot

This comment was marked as outdated.

@juncaipeng juncaipeng force-pushed the refine-pd-fetch-req branch from 7f32773 to 3814451 Compare May 6, 2026 09:43
Copilot AI review requested due to automatic review settings May 6, 2026 10:04
@juncaipeng juncaipeng force-pushed the refine-pd-fetch-req branch from 3814451 to b2ff318 Compare May 6, 2026 10:04
@juncaipeng juncaipeng changed the title prepare request in prefill instance by multi threads [PD] prepare request in prefill instance by multi threads May 6, 2026
PaddlePaddle-bot

This comment was marked as outdated.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

Comment thread fastdeploy/inter_communicator/engine_worker_queue.py
Comment thread fastdeploy/envs.py
Comment thread fastdeploy/engine/common_engine_prepare_mixin.py
Comment thread fastdeploy/splitwise/splitwise_connector.py
Comment thread fastdeploy/cache_manager/cache_messager.py
PaddlePaddle-bot

This comment was marked as outdated.

@juncaipeng juncaipeng force-pushed the refine-pd-fetch-req branch from b2ff318 to ccef373 Compare May 7, 2026 07:07
Copilot AI review requested due to automatic review settings May 12, 2026 04:06
@juncaipeng juncaipeng force-pushed the refine-pd-fetch-req branch from ccef373 to c9caba4 Compare May 12, 2026 04:06
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 3 comments.

Comment thread fastdeploy/engine/common_engine_prepare_mixin.py
Comment thread fastdeploy/cache_manager/cache_messager.py
Comment thread tests/cache_manager/test_cache_messager.py
PaddlePaddle-bot

This comment was marked as outdated.

block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index, current_step_prefilled_token_num in batch_engine_signals:
self._maybe_wait_for_cache_task(engine_index)
Collaborator Author

If the engine has already received the request and started inference but the cache messager has not yet received it, wait here until it arrives, to avoid errors. If the request never arrives, it is better to hang here than to proceed with a faulty cache transfer.

Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
tracing.trace_set_thread_info("Scheduler Task to Work")
get_request_pool = ThreadPoolExecutor(max_workers=1)
Collaborator Author

Moved the request-preparation functions into a separate file.

Value("i", 0) for _ in range(self.local_data_parallel_size)
]
self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)]
Collaborator Author

Removed the messager's acknowledgement signal for cache_task, to make concurrent request preparation easier to support.

PaddlePaddle-bot

This comment was marked as outdated.


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-05-12 13:31:37

📋 Review summary

PR overview: changes request preparation in the prefill instance from single-threaded to multi-threaded concurrent execution (5 threads by default). Decoupling the prepare and schedule threads improves throughput in the PD-disaggregated scenario; the old finished_add_cache_task synchronization is removed in favor of active polling on the cache_messager side.
Scope: fastdeploy/engine/ (new mixin), fastdeploy/cache_manager/, fastdeploy/inter_communicator/, fastdeploy/envs.py
Impact tags: [Engine] [KVCache] [PD Disaggregation]


📝 PR convention check

Two convention issues: (1) the title tag [PD] is not in the official list, and since the PR targets release/2.6 it should use the [Cherry-Pick] format; (2) the description lacks an ## Accuracy Tests section (it currently uses ## Tests; the section name must match the template exactly).

Suggested title (ready to copy):

  • [Cherry-Pick][PD Disaggregation] prepare request in prefill instance by multi threads (#original-PR-number)

Suggested PR description (ready to copy; must reproduce the full structure of the checklist §D2 template):

## Motivation
* Prepare requests in the prefill instance concurrently with multiple threads, reducing serial waiting in the request-preparation stage and improving throughput. The original implementation ran `preallocate_resource_in_p` → `send_splitwise_tasks` → `check_decode_allocated` sequentially on a single thread, so high response latency on the D side blocked preparation of subsequent requests. The new implementation moves request preparation into a dedicated prepare thread pool (5 threads by default) that runs in parallel with the schedule thread, and removes the old `finished_add_cache_task` synchronization in favor of active polling on the `cache_messager` side.
* With multiple P and multiple D instances, prevent a single resource-starved D instance from dragging down the throughput of a P instance.

## Modifications
- `fastdeploy/engine/common_engine_prepare_mixin.py` (new): extracts `EngineServicePrepareMixin`, implementing request preparation for the three roles `_fetch_request_mixed`, `_fetch_request_prefill`, and `_fetch_request_decode`, plus the `_fetch_loop` worker thread and the `_prepare_request_v1` entry point.
- `fastdeploy/engine/common_engine.py`: `EngineService` inherits `EngineServicePrepareMixin`; under `ENABLE_V1_KVCACHE_SCHEDULER`, the new `prepare_request_thread` and `schedule_request_thread` run in parallel; the former inline `ThreadPoolExecutor` and `_fetch_request` closure are removed.
- `fastdeploy/envs.py`: adds `FD_PREFILL_PREPARE_REQ_THREAD_NUM` (default 5) to control the number of request-preparation threads in the prefill instance.
- `fastdeploy/inter_communicator/engine_worker_queue.py`: removes the `finished_add_cache_task` queues, locks, barriers, and flags, along with the `put_finished_add_cache_task_req` / `get_finished_add_cache_task_req` methods.
- `fastdeploy/cache_manager/cache_messager.py`: adds a `_maybe_wait_for_cache_task` polling method, called from `prefill_layerwise_send_cache_thread` to actively wait until the cache task is ready; removes the old `finish_add_cache_task_barrier` synchronization point.

## Usage or Command
```bash
# Control the number of prefill request-preparation threads via an environment variable (default: 5)
export FD_PREFILL_PREPARE_REQ_THREAD_NUM=5
```

## Accuracy Tests
Benchmark setup: EB45 0.3B model, 1P1D, 1000 requests, concurrency 256, average input length 1280, average output length 700
* Baseline: 275s
* With concurrency 1: 263s
* With concurrency 5: 260s

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [x] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues

Level File Summary
🟡 Suggestion fastdeploy/cache_manager/cache_messager.py:925 _maybe_wait_for_cache_task has no timeout cap; if the D side misbehaves, prefill_layerwise_send_cache_thread blocks forever
❓ Question fastdeploy/engine/common_engine_prepare_mixin.py Thread safety of concurrent calls to preallocate_resource_in_p / send_splitwise_tasks

Overall assessment

The refactoring is clear overall: splitting the prepare/schedule responsibilities enables concurrent speedup, and the large cleanup of engine_worker_queue removes unneeded synchronization overhead. Attention is still needed on the potential permanent blocking in _maybe_wait_for_cache_task, and on confirming thread safety along the concurrent resource-allocation path.
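A bounded variant of the wait, addressing the blocking concern raised here, might look like this. This is illustrative only: the PR's _maybe_wait_for_cache_task has no such timeout, and the function name and signature below are hypothetical.

```python
import time

def wait_for_cache_task(cache_tasks, engine_index, timeout=30.0, poll=0.001):
    """Poll for the cache task, but give up after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while engine_index not in cache_tasks:
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"cache task for engine {engine_index} not ready after "
                f"{timeout}s; check whether the D instance is alive"
            )
        time.sleep(poll)
    return cache_tasks[engine_index]
```

Raising after a deadline lets the caller surface the failure (abort the transfer, log, or retry) instead of hanging the send thread indefinitely.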

Comment thread fastdeploy/cache_manager/cache_messager.py
@codecov-commenter

Codecov Report

❌ Patch coverage is 58.53659% with 68 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@a5191f2). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/engine/common_engine_prepare_mixin.py 56.73% 52 Missing and 9 partials ⚠️
fastdeploy/cache_manager/cache_messager.py 66.66% 4 Missing and 1 partial ⚠️
fastdeploy/engine/common_engine.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7724   +/-   ##
==============================================
  Coverage               ?   72.82%           
==============================================
  Files                  ?      379           
  Lines                  ?    53917           
  Branches               ?     8430           
==============================================
  Hits                   ?    39267           
  Misses                 ?    11881           
  Partials               ?     2769           
Flag Coverage Δ
GPU 72.82% <58.53%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 4e7a46e into PaddlePaddle:release/2.6 May 13, 2026
32 of 38 checks passed
@Jiang-Jia-Jun
Collaborator

This also needs to be submitted to develop.
